{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Federated Learning - MNIST Example</h1>\n",
    "<h2>Populate remote grid nodes with labeled tensors </h2>\n",
    "In this notebook, we will populate our grid nodes with labeled data so that it will be used later by people interested in train models.\n",
    "\n",
    "**NOTE:** At the time of running this notebook, we were running the grid components in background mode.  \n",
    "\n",
    "Components:\n",
    " - Grid Gateway(http://localhost:5000)\n",
    " - Grid Node Bob (http://localhost:3000)\n",
    " - Grid Node Alice (http://localhost:3001)\n",
    " \n",
    "This notebook was made based on <a href=\"https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb\">Part 10: Federated Learning with Encrypted Gradient Aggregation</a> tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Import dependencies</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "from syft.workers.node_client import NodeClient\n",
    "import torch\n",
    "import pickle\n",
    "import time\n",
    "import torchvision\n",
    "from torchvision import datasets, transforms\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Setup config</h2>\n",
    "Init hook, connect with grid nodes, etc..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)\n",
    "\n",
    "# Connect directly to grid nodes\n",
    "nodes = [\"ws://localhost:3000/\",\n",
    "         \"ws://localhost:3001/\"]\n",
    "\n",
    "compute_nodes = []\n",
    "for node in nodes:\n",
    "    compute_nodes.append( NodeClient(hook, node) )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 - Load Dataset\n",
    "\n",
    "The code below will load and preprocess an N amount of MNIST data samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SAMPLES = 10000\n",
    "MNIST_PATH = './dataset'\n",
    "\n",
    "transform = transforms.Compose([\n",
    "                              transforms.ToTensor(),\n",
    "                              transforms.Normalize((0.1307,), (0.3081,)),\n",
    "                              ])\n",
    "\n",
    "trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)\n",
    "\n",
    "dataiter = iter(trainloader)\n",
    "\n",
    "images_train_mnist, labels_train_mnist = dataiter.next()"
   ]
  },
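  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what `transforms.Normalize((0.1307,), (0.3081,))` does to each pixel, the sketch below applies the same `(x - mean) / std` formula by hand. The tensor `pixels` is a made-up example, not MNIST data.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "MEAN, STD = 0.1307, 0.3081  # same constants as in the transform above\n",
    "\n",
    "# Hypothetical pixel values in [0, 1], the range produced by ToTensor()\n",
    "pixels = torch.tensor([0.0, 0.1307, 1.0])\n",
    "\n",
    "# Normalize subtracts the mean and divides by the standard deviation\n",
    "normalized = (pixels - MEAN) / STD\n",
    "print(normalized)\n",
    "```\n",
    "\n",
    "A pixel equal to the mean maps to zero; darker pixels become negative and brighter ones positive."
   ]
  },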
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>2 - Split dataset </h2>\n",
    "We will split our dataset to send to nodes. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets_mnist = torch.split(images_train_mnist, int(len(images_train_mnist) / len(compute_nodes)), dim=0 ) #tuple of chunks (dataset / number of nodes)\n",
    "labels_mnist = torch.split(labels_train_mnist, int(len(labels_train_mnist) / len(compute_nodes)), dim=0 )  #tuple of chunks (labels / number of nodes)"
   ]
  },
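  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `torch.split` takes a chunk *size*, not a chunk count, so the split above yields exactly one chunk per node only when the number of samples is divisible by the number of nodes. A minimal sketch of the difference, using a made-up tensor of 10 elements and a hypothetical `n_nodes = 3`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "data = torch.arange(10)  # stand-in for the image tensor\n",
    "n_nodes = 3              # hypothetical node count\n",
    "\n",
    "# Chunk size 10 // 3 == 3 leaves a remainder, so an extra chunk appears\n",
    "by_size = torch.split(data, len(data) // n_nodes, dim=0)\n",
    "print([len(c) for c in by_size])\n",
    "\n",
    "# torch.chunk takes the requested number of chunks instead\n",
    "by_count = torch.chunk(data, n_nodes, dim=0)\n",
    "print([len(c) for c in by_count])\n",
    "```\n",
    "\n",
    "With 10,000 samples and two nodes the division is exact, so this only matters if you change `N_SAMPLES` or the node list."
   ]
  },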
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>3 - Tagging tensors</h2>\n",
    "The code below will add a tag (of your choice) to the data that will be sent to grid nodes. This tag is important as the gateway will need it to retrieve this data later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag_img = []\n",
    "tag_label = []\n",
    "\n",
    "\n",
    "for i in range(len(compute_nodes)):\n",
    "    tag_img.append(datasets_mnist[i].tag(\"#X\", \"#mnist\", \"#dataset\").describe(\"The input datapoints to the MNIST dataset.\"))\n",
    "    tag_label.append(labels_mnist[i].tag(\"#Y\", \"#mnist\", \"#dataset\").describe(\"The input labels to the MNIST dataset.\"))"
   ]
  },
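  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why tags matter, here is a toy sketch of tag-based lookup in plain Python. It is not how the grid implements search; `TagIndex`, `register`, `search` and the object ids are hypothetical names used only for illustration.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "class TagIndex:\n",
    "    '''Toy in-memory index mapping tags to object ids (hypothetical).'''\n",
    "\n",
    "    def __init__(self):\n",
    "        self._by_tag = defaultdict(set)\n",
    "\n",
    "    def register(self, obj_id, *tags):\n",
    "        # Record the object under every tag it carries\n",
    "        for tag in tags:\n",
    "            self._by_tag[tag].add(obj_id)\n",
    "\n",
    "    def search(self, *tags):\n",
    "        # Return the ids of objects that carry all requested tags\n",
    "        sets = [self._by_tag.get(tag, set()) for tag in tags]\n",
    "        return set.intersection(*sets) if sets else set()\n",
    "\n",
    "\n",
    "index = TagIndex()\n",
    "index.register('bob-images', '#X', '#mnist', '#dataset')\n",
    "index.register('bob-labels', '#Y', '#mnist', '#dataset')\n",
    "print(index.search('#X', '#mnist'))\n",
    "```\n",
    "\n",
    "Combining a general tag such as `#mnist` with a specific one such as `#X` narrows the result to the image tensors only."
   ]
  },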
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2> 4 - Sending our tensors to grid nodes</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "shared_x1 = tag_img[0].send(compute_nodes[0]) # First chunk of dataset to Bob\n",
    "shared_x2 = tag_img[1].send(compute_nodes[1]) # Second chunk of dataset to Alice\n",
    "\n",
    "shared_y1 = tag_label[0].send(compute_nodes[0]) # First chunk of labels to Bob\n",
    "shared_y2 = tag_label[1].send(compute_nodes[1]) # Second chunk of labels to Alice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X tensor pointers:  (Wrapper)>[PointerTensor | me:48155914346 -> Bob:84790492900]\n",
      "\tTags: #mnist #X #dataset \n",
      "\tShape: torch.Size([5000, 1, 28, 28])\n",
      "\tDescription: The input datapoints to the MNIST dataset.... (Wrapper)>[PointerTensor | me:14916691261 -> Alice:53510693506]\n",
      "\tTags: #mnist #X #dataset \n",
      "\tShape: torch.Size([5000, 1, 28, 28])\n",
      "\tDescription: The input datapoints to the MNIST dataset....\n",
      "Y tensor pointers:  (Wrapper)>[PointerTensor | me:34933432192 -> Bob:96713703699]\n",
      "\tTags: #mnist #Y #dataset \n",
      "\tShape: torch.Size([5000])\n",
      "\tDescription: The input labels to the MNIST dataset.... (Wrapper)>[PointerTensor | me:47811294626 -> Alice:47171429524]\n",
      "\tTags: #mnist #Y #dataset \n",
      "\tShape: torch.Size([5000])\n",
      "\tDescription: The input labels to the MNIST dataset....\n"
     ]
    }
   ],
   "source": [
    "print(\"X tensor pointers: \", shared_x1, shared_x2)\n",
    "print(\"Y tensor pointers: \", shared_y1, shared_y2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Disconnect nodes</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(compute_nodes)):\n",
    "    compute_nodes[i].close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
